#!/usr/bin/env python
# coding: utf-8

## Transcription and diarization of audio files with Whisper and pyannote
## Francesco Garassino
## 
## This code is adapted from that provided by `riteshhere` on [GitHub](https://github.com/riteshhere/Speaker_diarization/blob/5d39ef36dd7c4c20099be278c7c5cf86a043174b/research_files/speech_Diarization.ipynb) and explained on 
## [Medium](https://medium.com/@xriteshsharmax/speaker-diarization-using-whisper-asr-and-pyannote-f0141c85d59a).
##


# -- imports --

import whisper
import datetime

import subprocess

import torch
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding

from pyannote.audio import Audio
from pyannote.core import Segment

import wave
import contextlib

from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time
import random
import pandas as pd

import argparse


# -- functions --

def _positive_int(value):
	"""
	argparse `type` callback: convert `value` to an int and reject anything below 1.

	Raises argparse.ArgumentTypeError, which argparse turns into a usage error.
	"""
	try:
		number = int(value)
	except ValueError:
		raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None

	if number < 1:
		raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")

	return number


def parse_args():
	"""
	Parse command line arguments for the audio transcription and diarization script.

	All five options are required. `--num_speakers` must be an integer >= 1; an invalid
	value is rejected here, before any transcription work is started.

	Returns: A tuple with the parsed command line arguments, in this order:
	input_file (str), output_dir (str), language (str), num_speakers (int) and model_size (str).

	Raises: SystemExit (raised by argparse) when an argument is missing or invalid.
	"""

	# Create the parser
	parser = argparse.ArgumentParser(description="This script performs automated transcription and diarization of an audio file.")

	# Add arguments
	parser.add_argument('-i', '--input_file', type=str, help='The audio file path', required=True)
	parser.add_argument('-o', '--output_dir', type=str, help='The folder in which the outputs will be stored', required=True)
	parser.add_argument('-l', '--language', type=str, choices=['any', 'English'], help='The language of the audio file', required=True)
	parser.add_argument('-n', '--num_speakers', type=_positive_int, help='The number of speakers/voices present in the audio', required=True)
	parser.add_argument('-s', '--model_size', type=str, choices=['tiny', 'base', 'small', 'medium', 'large-v2'], help='The "size", i.e. the type, of model to use with Whisper', required=True)

	# Parse the arguments
	args = parser.parse_args()

	return args.input_file, args.output_dir, args.language, args.num_speakers, args.model_size

def segment_embedding(segment, duration):
	"""
	Compute the speaker embedding for one transcript segment.

	This function reads three module-level names that the main section of the script
	defines before calling it: `audio` (used to crop the file), `path` (the audio file
	to crop from) and `embedding_model` (the model that produces the embedding).

	Parameters:
	segment (dict): A dictionary with at least the "start" and "end" times of the segment.
	duration (float): The duration in seconds of the audio file each segment comes from.

	Returns:
	Whatever `embedding_model` returns for the cropped waveform. The caller stores it as
	one row of 192 values characterizing the speaker of the segment.
	"""

	# Whisper overshoots the end timestamp in the last segment, so the end of the
	# time window is never allowed to go past the calculated duration of the file.
	# Segment (from pyannote.core) represents a time interval with a start and an end.
	window = Segment(segment["start"], min(duration, segment["end"]))

	# audio.crop() hands back a pair: the audio data for the requested window and the
	# sample rate. Only the audio data is needed here.
	waveform, _ = audio.crop(path, window)

	# Indexing with None adds a leading axis before the data is passed to the model
	# (presumably the batch dimension the model expects - confirm against pyannote docs).
	batch = waveform[None]

	return embedding_model(batch)

# human-readable timestamps for the transcript
def time_secs(secs):
	"""
	Turn a number of seconds into a datetime.timedelta with whole-second precision.

	The value is rounded with the built-in round() first, so the string form of the
	result looks like 'H:MM:SS' without a fractional part.
	"""
	whole_seconds = round(secs)
	return datetime.timedelta(seconds=whole_seconds)


# -- main --

if __name__ == '__main__':

	# `os` is needed only by this script section, so it is imported here rather
	# than in the import block at the top of the file.
	import os

	# Run the parser and access the arguments
	path, out_dir, language, num_speakers, model_size = parse_args()

	# Make sure the output folder exists (an empty string means the current folder)
	out_dir = out_dir or '.'
	os.makedirs(out_dir, exist_ok=True)

	# When the audio is in English, the '.en' variant of the chosen model is used,
	# except for 'large-v2', whose name is used as is.
	model_name = model_size
	if language == 'English' and model_size != 'large-v2':
		model_name += '.en'

	print('-------------')
	print(f'Processing file {os.path.basename(path)}...\n')

	# The `wave` module is used further down to read the duration of the audio, so the
	# file has to be in `.wav` format. We can use `ffmpeg` to convert it, if needed.
	# os.path.splitext() removes only the final extension, so dots elsewhere in the
	# path (e.g. './recordings/my.file.mp3') are left untouched.
	stem, extension = os.path.splitext(path)

	if extension.lower() != '.wav':
		print('Converting file to WAV format...\n')
		# same location and name as the original file, but with a .wav extension
		wav_path = stem + '.wav'
		# -y (overwrite existing output) is given before the input, -i specifies the input file;
		# check=True stops the script here if ffmpeg reports a failure
		subprocess.run(['ffmpeg', '-y', '-i', path, wav_path], check=True)
		# NOTE: segment_embedding() reads `path` as a module-level name, so it must
		# point at the .wav file from here on.
		path = wav_path

	# Name shared by all output files: the audio file name without folder and extension
	file_stem = os.path.splitext(os.path.basename(path))[0]


	# Transcription
	# We can now load the whisper model we specified above:

	print('Running Whisper...\n')

	model = whisper.load_model(model_name, device="cpu")
	# Note that with 'device', we could have whisper run on a GPU as well - though this does not appear to come with
	# improved performance. See https://github.com/ggerganov/whisper.cpp/issues/1540


	# Time to run whisper on our file!

	# model.transcribe returns a dictionary with three key-value pairs:

	# 1. the transcribed text as a single string
	# 2. the transcription "segments", i.e. the sentences identified by whisper, complete with timestamps
	# 3. the detected language (when using the 'large' family models)

	start_time = time.time()

	result = model.transcribe(path)
	segments = result["segments"]

	end_time = time.time()

	print(f"Elapsed time: {round((end_time - start_time)/60)} minutes\n")

	# Clustering into `num_speakers` groups and projecting the embeddings onto two
	# principal components both need at least that many segments; stop with a clear
	# message instead of failing later on.
	if len(segments) < max(2, num_speakers):
		raise SystemExit(
			f'Whisper found {len(segments)} segment(s), which is not enough '
			f'to tell {num_speakers} speaker(s) apart.'
		)

	# We now need to extract and calculate essential information from our .wav file:

	# 1. The total number of frames, i.e. the number of single units of audio data that make up our audio file
	# 2. The frame rate, also known as the sample rate, i.e. the number of frames contained in a second of the audio file
	# 3. The calculated duration of the audio file in seconds

	# The `wave` module opens the .wav file; the `with` statement makes sure the file
	# is properly closed after its use, even if an error occurs.
	with wave.open(path, 'rb') as wav_file:
		frames = wav_file.getnframes()
		rate = wav_file.getframerate()

	duration = frames / float(rate)


	# Diarization

	print('Running diarization...\n')

	# Our first step towards diarization is to retrieve speaker *embeddings* from each of the segments (i.e., sentences)
	# that whisper extracted from the audio file.

	# *Speaker embeddings* are numerical representations of a person's voice. They capture the unique characteristics
	# of a speaker's voice and will allow us to tell different speakers apart in our audio file.

	# Audio is a class imported from pyannote.audio that offers functionality for loading audio files, extracting features,
	# and manipulating audio data.
	# NOTE: segment_embedding() reads `audio` and `embedding_model` as module-level names.
	audio = Audio()

	# The "device" specification can be used to direct pyannote.audio processes to the GPU, making it up to 10x faster.
	# Pick the best device available on this machine instead of assuming Apple's "mps":
	# https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
	if torch.backends.mps.is_available():
		device = torch.device("mps")
	elif torch.cuda.is_available():
		device = torch.device("cuda")
	else:
		device = torch.device("cpu")

	# PretrainedSpeakerEmbedding utilizes a model that has been pre-trained on a large dataset of speech recordings.
	# The primary function of this class is to extract speaker embeddings from audio segments.
	embedding_model = PretrainedSpeakerEmbedding(
		"speechbrain/spkrec-ecapa-voxceleb",
		device=device)

	# Let's now iterate over our segments and retrieve all embeddings

	embeddings = np.zeros(shape=(len(segments), 192)) # a bi-dimensional array of size #segments*192, filled with zeros

	for i, segment in enumerate(segments):
		embeddings[i] = segment_embedding(segment, duration)

	# replace NaN values so that they cannot break the clustering
	embeddings = np.nan_to_num(embeddings)

	# Essentially, we have created an array in which each row corresponds to one of the segments of our transcript,
	# and contains 192 values describing the characteristics of the voice present in that segment.


	# Now we can *cluster* all our embeddings based on their values across the 192 columns. For this, we will perform
	# [agglomerative hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) with a function
	# imported from the `sklearn.cluster` module.

	# Having specified the number of speakers, we will force the algorithm to cluster the embeddings into as many
	# clusters as there are speakers.

	clustering = AgglomerativeClustering(num_speakers).fit(embeddings)

	# the embeddings are assigned a numerical label from [0,num_speakers)
	labels = clustering.labels_

	# with this, we will add a "speaker" key-value pair to the dictionary corresponding to each segment in the 'segments' list.
	# Note the 'label + 1', making sure there is no "SPEAKER 0" label
	for segment, label in zip(segments, labels):
		segment["speaker"] = 'SPEAKER ' + str(label + 1)


	# All there is left to do is to print our labelled segments (i.e., sentences) to a transcript `.txt` file

	print('Creating transcript file...\n')

	# same name as the audio file, but with a '_diarization.txt' suffix
	out_path = os.path.join(out_dir, f'{file_stem}_diarization.txt')

	# utf-8 is set explicitly so that non-ASCII transcripts can be written on every platform
	with open(out_path, 'w', encoding='utf-8') as transcript:
		previous_speaker = None
		for segment in segments:
			# start a new paragraph, with speaker label and timestamp, whenever the speaker changes
			if segment["speaker"] != previous_speaker:
				transcript.write("\n" + segment["speaker"] + ' ' + str(time_secs(segment["start"])) + '\n')
			# drop the single leading space of the segment text, if there is one
			transcript.write(segment["text"].removeprefix(' ') + ' ')
			previous_speaker = segment["speaker"]

	# We can get a first impression of how the clustering worked by making a PCA plot of the clustered embeddings:

	print('Making a PCA plot of clustering of segments...\n')

	# Perform PCA to reduce the dimensionality of embeddings to 2D
	pca = PCA(n_components=2, random_state=42)
	embeddings_2d = pca.fit_transform(embeddings)

	# Create a colormap for speakers, ensuring each speaker gets a unique color
	speaker_labels = np.unique(labels)
	colors = cm.tab20b(np.linspace(0, 1, len(speaker_labels)))

	# Plot the clusters: one scatter call per speaker, so each speaker shows up exactly once in the legend
	plt.figure(figsize=(10, 8))
	for color, speaker_label in zip(colors, speaker_labels):
		points = embeddings_2d[labels == speaker_label]
		plt.scatter(points[:, 0], points[:, 1], label=f'SPEAKER {speaker_label + 1}', color=color)

	plt.legend()

	plt.title("Speaker Diarization Clusters (PCA Visualization)")
	plt.xlabel("Principal Component 1")
	plt.ylabel("Principal Component 2")

	plot_path = os.path.join(out_dir, f'{file_stem}_diarization_PCA.png')
	plt.savefig(plot_path, bbox_inches='tight')

	print('Done.\n')
